{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fc5bde60-1899-461d-8083-3ee04ac7c099",
   "metadata": {},
   "source": [
    "# 模型推理 - 使用 QLoRA 微调后的 ChatGLM3-6B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3292b88c-91f0-48d2-91a5-06b0830c7e70",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig\n",
    "from peft import PeftModel, PeftConfig\n",
    "\n",
    "# 定义全局变量和参数\n",
    "model_name_or_path = 'THUDM/chatglm3-6b'  # 模型ID或本地路径\n",
    "peft_model_path = f\"models/demo/{model_name_or_path}\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9f81454c-24b2-4072-ab05-b25f9b120ae6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a6a3dddcd9df4715a4b693559cf30cff",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "config = PeftConfig.from_pretrained(peft_model_path)\n",
    "\n",
    "q_config = BitsAndBytesConfig(load_in_4bit=True,\n",
    "                              bnb_4bit_quant_type='nf4',\n",
    "                              bnb_4bit_use_double_quant=True,\n",
    "                              bnb_4bit_compute_dtype=torch.float32)\n",
    "\n",
    "base_model = AutoModel.from_pretrained(config.base_model_name_or_path,\n",
    "                                       quantization_config=q_config,\n",
    "                                       trust_remote_code=True,\n",
    "                                       device_map='auto')\n",
    "base_model.requires_grad_(False)\n",
    "base_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "342b3659-d644-4232-8af1-f092e733bf40",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "6d23e720-dee1-4b43-a298-0cbe1d8ad11d",
   "metadata": {},
   "source": [
    "## 微调前后效果对比\n",
    "\n",
    "### ChatGLM-6B\n",
    "\n",
    "```\n",
    "输入：\n",
    "\n",
    "类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领\n",
    "\n",
    "ChatGLM-6B 微调前输出：\n",
    "\n",
    "* 版型：修身\n",
    "* 显瘦：True\n",
    "* 风格：文艺\n",
    "* 简约：True\n",
    "* 图案：印花\n",
    "* 撞色：True\n",
    "* 裙下摆：直筒或微喇\n",
    "* 裙长：中长裙\n",
    "* 连衣裙：True\n",
    "\n",
    "ChatGLM-6B 微调后输出：\n",
    "\n",
    "一款简约而不简单的连衣裙，采用撞色的印花点缀，打造文艺气息，简约的圆领，修饰脸型。衣袖和裙摆的压褶，增添设计感，修身的版型，勾勒出窈窕的身材曲线。\n",
    "```\n",
    "\n",
    "### ChatGLM2-6B\n",
    "\n",
    "```\n",
    "输入：\n",
    "类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领\n",
    "\n",
    "微调前：\n",
    "这款裙子,版型显瘦,采用简约文艺风格,图案为印花和撞色设计,裙下摆为压褶裙摆,裙长为连衣裙,适合各种场合穿着,让你舒适自在。圆领设计,优雅清新,让你在任何场合都充满自信。如果你正在寻找一款舒适、时尚、优雅的裙子,不妨 考虑这款吧!\n",
    "\n",
    "微调后: \n",
    "这款连衣裙简约的设计，撞色印花点缀，丰富了视觉，上身更显时尚。修身的版型，贴合身形，穿着舒适不束缚。圆领的设计，露出精致锁骨，尽显女性优雅气质。下摆压褶的设计，增添立体感，行走间更显飘逸。前短后长的设计，显 得身材比例更加完美。文艺的碎花设计，更显精致。\n",
    "```\n",
    "\n",
    "### ChatGLM3-6B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9d7757a4-7d1f-488f-8d80-b73dfa4863d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "输入：\n",
      "类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领\n"
     ]
    }
   ],
   "source": [
    "input_text = '类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领'\n",
    "print(f'输入：\\n{input_text}')\n",
    "tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2d30fce1-e01f-4303-aa55-ed004eaa22a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ChatGLM3-6B 微调前：\n",
      "连衣裙是女孩子们最爱的单品之一，这款连衣裙采用撞色圆领设计，简洁大方，修饰脸型。衣身采用印花图案点缀，展现出优雅文艺的气质。袖子采用压褶设计，轻松遮肉显瘦，修饰臂部线条。简约的款式设计，穿着舒适，轻松百搭。\n"
     ]
    }
   ],
   "source": [
    "response, history = base_model.chat(tokenizer=tokenizer, query=input_text)\n",
    "print(f'ChatGLM3-6B 微调前：\\n{response}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "38b5a770-baef-4697-bb71-6088e3a43d59",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ChatGLM3-6B 微调后: \n",
      "这款连衣裙简约的圆领设计，凸显出修长的颈部线条，而衣身采用撞色印花设计，更显俏皮甜美，而袖口和裙摆处加入压褶设计，增添层次感，更具艺术感，而裙身整体设计简约大气，彰显出优雅文艺的气质。\n"
     ]
    }
   ],
   "source": [
    "model = PeftModel.from_pretrained(base_model, peft_model_path)\n",
    "response, history = model.chat(tokenizer=tokenizer, query=input_text)\n",
    "print(f'ChatGLM3-6B 微调后: \\n{response}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cf454e0-f0f5-4fb0-ab90-83e9615f132a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
